We are building a model from this week’s #tidytuesday data set on volcano eruptions. The goal is to build a multiclass random forest classifier that predicts the type of volcano from other characteristics such as latitude, longitude, and tectonic setting.
library(tidyverse)
volcano_raw <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-05-12/volcano.csv')
volcano_raw %>%
  count(primary_volcano_type, sort = TRUE)
## # A tibble: 26 x 2
## primary_volcano_type n
## <chr> <int>
## 1 Stratovolcano 353
## 2 Stratovolcano(es) 107
## 3 Shield 85
## 4 Volcanic field 71
## 5 Pyroclastic cone(s) 70
## 6 Caldera 65
## 7 Complex 46
## 8 Shield(s) 33
## 9 Submarine 27
## 10 Lava dome(s) 26
## # … with 16 more rows
We have fewer than 1,000 data points spread over 26 types, so we won’t model every type separately. Instead, we collapse similar types and consider only 3 major classes:
- stratovolcano
- shield volcano
- the rest (other)
volcano_df <- volcano_raw %>%
  transmute(
    volcano_type = case_when(
      str_detect(primary_volcano_type, "Stratovolcano") ~ "Stratovolcano",
      str_detect(primary_volcano_type, "Shield") ~ "Shield",
      TRUE ~ "Other"
    ),
    volcano_number, latitude, longitude, elevation,
    tectonic_settings, major_rock_1
  ) %>%
  mutate_if(is.character, factor)
volcano_df %>%
  count(volcano_type, sort = TRUE)
## # A tibble: 3 x 2
## volcano_type n
## <fct> <int>
## 1 Stratovolcano 461
## 2 Other 379
## 3 Shield 118
When we have spatial information, it’s always a good idea to display it on a map.
world <- map_data("world")
ggplot() +
  geom_map(data = world, map = world,
           aes(long, lat, map_id = region),
           color = "white", fill = "gray50", alpha = 0.2) +
  geom_point(data = volcano_df,
             aes(longitude, latitude, color = volcano_type),
             alpha = 0.8)
library(tidymodels)
volcano_boot <- bootstraps(volcano_df, times = 500)
volcano_boot
## # Bootstrap sampling
## # A tibble: 500 x 2
## splits id
## <list> <chr>
## 1 <split [958/348]> Bootstrap001
## 2 <split [958/353]> Bootstrap002
## 3 <split [958/362]> Bootstrap003
## 4 <split [958/352]> Bootstrap004
## 5 <split [958/356]> Bootstrap005
## 6 <split [958/354]> Bootstrap006
## 7 <split [958/356]> Bootstrap007
## 8 <split [958/348]> Bootstrap008
## 9 <split [958/353]> Bootstrap009
## 10 <split [958/352]> Bootstrap010
## # … with 490 more rows
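Each split above is printed as [analysis/assessment]. The analysis set is drawn with replacement and has as many rows as the data (958), while the assessment set holds the rows that were never drawn. A minimal base R sketch, assuming only the row count of 958 from the output above (the seed and variable names are illustrative), shows why the assessment sets hover around 350 rows:

```r
set.seed(2020)  # arbitrary seed, only to make this sketch reproducible

n <- 958  # number of rows in volcano_df, taken from the output above

# One bootstrap resample: draw n row indices with replacement
analysis_rows <- sample(n, size = n, replace = TRUE)

# Rows that were never drawn form the held-out (assessment) set
assessment_rows <- setdiff(seq_len(n), analysis_rows)

length(analysis_rows)    # always n
length(assessment_rows)  # varies from resample to resample

# Expected held-out fraction: (1 - 1/n)^n, which is close to exp(-1)
expected_fraction <- (1 - 1 / n)^n
round(n * expected_fraction)  # about 352
```

So roughly 37% of the rows are left out of each resample, and those rows are the ones used to measure performance.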
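The themis package provides recipe steps for imbalanced data sets; here we use step_smote to oversample the less common classes. Some categorical features have too many levels for such a small data set, so we apply step_other to those variables (tectonic_settings and major_rock_1) to pool their infrequent levels.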
library(themis)
volcano_rec <- recipe(volcano_type ~ ., data = volcano_df) %>%
  update_role(volcano_number, new_role = "Id") %>%
  step_other(tectonic_settings) %>%
  step_other(major_rock_1) %>%
  step_dummy(tectonic_settings, major_rock_1) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_predictors()) %>%
  step_smote(volcano_type)
volcano_prep <- prep(volcano_rec)
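To give some intuition for the last recipe step, here is a rough base R sketch of the idea behind SMOTE: pick a point from the minority class, pick one of its nearest minority-class neighbours, and create a synthetic point somewhere on the line between them. The toy data, the helper `smote_one()`, and the choice of k = 5 neighbours are all assumptions made for illustration; this is not the themis implementation.

```r
set.seed(42)  # arbitrary seed for this sketch

# Toy minority-class points in two numeric predictors (hypothetical data)
minority <- matrix(rnorm(20), ncol = 2,
                   dimnames = list(NULL, c("x1", "x2")))

# Create one synthetic point (hypothetical helper, not part of themis)
smote_one <- function(x, k = 5) {
  i <- sample(nrow(x), 1)                    # pick a minority point at random
  d <- sqrt(colSums((t(x) - x[i, ])^2))      # Euclidean distance to every point
  neighbours <- order(d)[2:(k + 1)]          # k nearest points, excluding itself
  j <- neighbours[sample(length(neighbours), 1)]
  gap <- runif(1)                            # random position along the segment
  x[i, ] + gap * (x[j, ] - x[i, ])           # interpolate between the two points
}

# Five synthetic minority-class rows
synthetic <- t(replicate(5, smote_one(minority)))
synthetic
```

Because the method relies on distances between numeric points, it makes sense that the recipe creates dummy variables and normalizes the predictors before calling step_smote.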
rf_spec <- rand_forest(trees = 1000) %>%
  set_mode("classification") %>%
  set_engine("ranger")
volcano_wf <- workflow() %>%
  add_recipe(volcano_rec) %>%
  add_model(rf_spec)
volcano_wf
## ══ Workflow ═══════════════════════════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: rand_forest()
##
## ── Preprocessor ───────────────────────────────────────────────────────────────────────────
## 6 Recipe Steps
##
## ● step_other()
## ● step_other()
## ● step_dummy()
## ● step_zv()
## ● step_normalize()
## ● step_smote()
##
## ── Model ──────────────────────────────────────────────────────────────────────────────────
## Random Forest Model Specification (classification)
##
## Main Arguments:
## trees = 1000
##
## Computational engine: ranger
For each bootstrap resample, the recipe is estimated on the analysis set and the model is then fit to the preprocessed data. After that, predictions are made on the held-out assessment set.
volcano_res <- fit_resamples(
  volcano_wf,
  resamples = volcano_boot,
  control = control_resamples(save_pred = TRUE,
                              verbose = TRUE)
)
How good is our model?
volcano_res %>%
  collect_metrics()
## # A tibble: 2 x 5
## .metric .estimator mean n std_err
## <chr> <chr> <dbl> <int> <dbl>
## 1 accuracy multiclass 0.653 25 0.00332
## 2 roc_auc hand_till 0.791 25 0.00378
volcano_res %>%
  collect_predictions() %>%
  conf_mat(volcano_type, .pred_class)
## Truth
## Prediction Other Shield Stratovolcano
## Other 1938 328 746
## Shield 274 552 211
## Stratovolcano 1273 212 3247
volcano_res %>%
  collect_predictions() %>%
  group_by(id) %>%
  ppv(volcano_type, .pred_class) %>%
  ggplot(aes(.estimate)) +
  geom_histogram(bins = 10)
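To make the positive predictive value (PPV) concrete, it can be worked out by hand from the pooled confusion matrix printed earlier. The sketch below only re-enters those counts in base R; the macro average at the end (the simple mean across classes) is one common way to summarise a multiclass PPV and is an assumption here, and the pooled numbers need not equal the per-resample values in the histogram.

```r
# Pooled confusion matrix copied from the output above
# (rows = predictions, columns = truth)
classes <- c("Other", "Shield", "Stratovolcano")
cm <- matrix(
  c(1938,  328,  746,
     274,  552,  211,
    1273,  212, 3247),
  nrow = 3, byrow = TRUE,
  dimnames = list(Prediction = classes, Truth = classes)
)

# Accuracy: correct predictions (the diagonal) over all predictions
accuracy <- sum(diag(cm)) / sum(cm)
round(accuracy, 3)  # 0.653, in line with collect_metrics()

# PPV per class: correct predictions of a class over all predictions of that class
ppv_by_class <- diag(cm) / rowSums(cm)
round(ppv_by_class, 3)

# Macro average across the three classes
mean(ppv_by_class)
```

In this matrix, Shield has the lowest PPV: only about 53% of the volcanoes predicted to be shield volcanoes really are.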
Now we fit the model once to the whole preprocessed data set to look at feature importance.
library(vip)
rf_spec %>%
  set_engine("ranger", importance = "permutation") %>%
  fit(volcano_type ~ .,
      data = juice(volcano_prep) %>%
        select(-volcano_number) %>%
        janitor::clean_names()) %>%
  vip(geom = "point")
In the plot above, longitude and latitude are the most important predictors of volcano type. The dummy variable for the major rock type basalt / picro-basalt has the next biggest impact on the model. Since longitude and latitude are the most important features, we can go back to the map to see where the model predicts better. We can also use facet_wrap by volcano type to see which types are predicted better, and where.
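The idea behind permutation importance can be sketched in a few lines of base R: shuffle one predictor so that its link to the outcome is broken, and measure how much worse the predictions get. The example below is an assumption-laden illustration, not what ranger does internally: it uses a linear model on the built-in mtcars data instead of a random forest on the volcano data, and the helper names are made up.

```r
set.seed(1)  # arbitrary seed for this sketch

# Toy model on a built-in data set (stand-in for the volcano model)
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)

# Mean squared error of a model on a data set
mse <- function(model, data) {
  mean((data$mpg - predict(model, newdata = data))^2)
}
baseline <- mse(fit, mtcars)

# Importance of one predictor: average increase in error after shuffling it
permutation_importance <- function(var, n_repeats = 50) {
  increases <- replicate(n_repeats, {
    shuffled <- mtcars
    shuffled[[var]] <- sample(shuffled[[var]])  # break the link to the outcome
    mse(fit, shuffled) - baseline
  })
  mean(increases)
}

predictors <- c("wt", "hp", "qsec")
importance <- sapply(predictors, permutation_importance)
sort(importance, decreasing = TRUE)
```

A predictor whose shuffling barely changes the error contributes little to the model, while a large increase, as we would expect for longitude and latitude here, marks a feature the model relies on.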
volcano_pred <- volcano_res %>%
  collect_predictions() %>%
  mutate(correct = volcano_type == .pred_class) %>%
  left_join(volcano_df %>% mutate(.row = row_number()))
ggplot() +
  geom_map(data = world, map = world,
           aes(long, lat, map_id = region),
           color = "white", fill = "gray50", alpha = 0.2) +
  stat_summary_hex(data = volcano_pred,
                   aes(longitude, latitude, z = correct),
                   fun = "mean",
                   alpha = 0.5,
                   size = 0.5, bins = 60) +
  scale_fill_gradient(high = "cyan3", labels = scales::percent) +
  labs(x = "longitude", y = "latitude", fill = "Percent classified\ncorrectly")
ggplot() +
  geom_map(data = world, map = world,
           aes(long, lat, map_id = region),
           color = "white", fill = "gray50", alpha = 0.2) +
  stat_summary_hex(data = volcano_pred,
                   aes(longitude, latitude, z = correct),
                   fun = "mean",
                   alpha = 0.5,
                   size = 0.5, bins = 60) +
  facet_wrap(~ volcano_type) +
  scale_fill_gradient(high = "cyan3", labels = scales::percent) +
  labs(x = "longitude", y = "latitude", fill = "Percent classified\ncorrectly")